# Tested with: Python 3.10, numpy 1.26, torch 2.2, gymnasium 0.29
# ------------------------------------------------------------
from __future__ import annotations
import math, random, copy, os, enum
from dataclasses import dataclass
from typing import Tuple, List, Any, Callable

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils import parameters_to_vector, vector_to_parameters

# Prefer gymnasium; fallback to gym if needed
try:
    import gymnasium as gym
except ImportError:
    import gym

# wandb is optional: if importing it fails for any reason, fall back to a
# stand-in object whose init/log/finish accept anything and do nothing.
try:
    import wandb
    WANDB_AVAILABLE = True
except Exception:
    WANDB_AVAILABLE = False

    class _WandbStub:
        """No-op replacement for the small part of the wandb API this script calls."""

        def init(self, *args, **kwargs):
            return None

        def log(self, *args, **kwargs):
            return None

        def finish(self, *args, **kwargs):
            return None

    wandb = _WandbStub()

# ------------------------------------------------------------
# 0.  Utility
# ------------------------------------------------------------
# Device that models are moved to. Switching to torch.device("cuda") is possible,
# but NOTE(review): some helpers in this file may still build tensors on the CPU
# without passing a device; verify before switching.
DEVICE = torch.device("cpu")

@dataclass
class StepT:
    """Outcome of one ``WalkerEnv.step`` call."""

    s: Any  # observation after the step
    r: float  # reward for this transition
    done: bool  # episode finished (terminated, truncated, or step cap reached)

def set_global_seeds(seed: int):
    """Seed Python's ``random``, NumPy's global RNG, and PyTorch with ``seed`` (in that order)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

def uniform_unit_sphere(d: int) -> torch.Tensor:
    """Sample a direction uniformly from the unit sphere in R^d.

    A standard Gaussian draw is rotation-invariant, so rescaling it to unit
    length yields a uniform direction. The 1e-12 term only guards against a
    zero-length draw.
    """
    gauss = torch.randn(d)
    return gauss.div(gauss.norm(p=2) + 1e-12)

def block_rademacher(d: int, coords: List[int]) -> torch.Tensor:
    """Return a length-``d`` vector that is zero everywhere except on ``coords``.

    Each listed coordinate gets an independent random sign scaled by 1/sqrt(K),
    where K = len(coords), so distinct coords give a unit-norm vector. An empty
    ``coords`` list yields the zero vector.
    """
    out = torch.zeros(d)
    n_sel = len(coords)
    if not n_sel:
        return out
    rademacher = torch.randint(0, 2, (n_sel,)).mul(2).sub(1)  # entries in {-1, +1}
    out[coords] = rademacher.float() / math.sqrt(n_sel)
    return out

# ------------------------------------------------------------
# 1.  Walker environment wrapper (continuous control)
# ------------------------------------------------------------
class WalkerEnv:
    """
    Wrapper for continuous-action walker envs. Default: BipedalWalker-v3.

    Smooths over Gym/Gymnasium API differences (``reset`` returning ``obs`` or
    ``(obs, info)``; ``step`` returning a 4- or 5-tuple), clips actions to the
    action-space bounds, and optionally ends episodes after ``max_steps`` steps.
    """
    def __init__(self, env_id: str = "BipedalWalker-v3", max_steps: int | None = None, seed: int = 43):
        self.env_id = env_id
        self.seed = seed
        self.max_steps = max_steps
        self.env = gym.make(self.env_id)
        self.t = 0
        # Seed the env once at construction; later resets may pass their own seed.
        _ = self._reset_compat(self.seed)
        self.obs_dim = int(self.env.observation_space.shape[0])
        self.act_dim = int(self.env.action_space.shape[0])
        self.low = np.array(self.env.action_space.low, dtype=np.float32)
        self.high = np.array(self.env.action_space.high, dtype=np.float32)

    def _reset_compat(self, seed: int | None):
        """Reset the underlying env, applying ``seed`` when given; return only the observation.

        Gymnasium and newer gym accept ``reset(seed=...)``. Old gym rejects the
        keyword with ``TypeError``, in which case the seed is applied through
        ``env.seed(seed)`` (when available) before a plain ``reset()``.
        ``seed=None`` resets without reseeding.
        """
        if seed is None:
            out = self.env.reset()
        else:
            try:
                out = self.env.reset(seed=seed)
            except TypeError:
                if hasattr(self.env, "seed"):
                    self.env.seed(seed)
                out = self.env.reset()
        # Gymnasium / gym>=0.26 return (obs, info); older gym returns obs.
        return out[0] if isinstance(out, tuple) else out

    def reset(self, seed: int | None = None):
        """Start a new episode (optionally seeded) and return the first observation."""
        self.t = 0
        return self._reset_compat(seed=seed)

    def step(self, a: np.ndarray) -> StepT:
        """Apply action ``a`` (clipped to the action bounds) and return a ``StepT``.

        ``done`` is True when the env terminates or truncates, or when this
        wrapper's own ``max_steps`` cap is reached.
        """
        a = np.clip(a, self.low, self.high)
        out = self.env.step(a)
        if len(out) == 5:
            obs, r, terminated, truncated, _info = out
        else:
            obs, r, done, _info = out
            terminated, truncated = done, False
        self.t += 1
        done = bool(terminated or truncated or (self.max_steps is not None and self.t >= self.max_steps))
        return StepT(obs, float(r), done)

    def close(self):
        """Release the underlying env's resources."""
        self.env.close()

# ------------------------------------------------------------
# 2.  Gaussian policy (tanh-squashed Normal)
# ------------------------------------------------------------
class GaussianPolicy(nn.Module):
    """Diagonal Gaussian policy whose samples are squashed through ``tanh``.

    An MLP maps observations to the pre-squash mean ``mu``; a learnable,
    state-independent ``log_std`` gives the standard deviation. Actions are
    ``tanh(mu + std * eps)`` with ``eps ~ N(0, I)``, so each component lies in [-1, 1].
    """

    def __init__(self, obs_dim: int, act_dim: int, hidden=(64, 64), log_std_init: float = -0.5):
        super().__init__()
        layers = []
        last = obs_dim
        for h in hidden:
            layers += [nn.Linear(last, h), nn.Tanh()]
            last = h
        layers += [nn.Linear(last, act_dim)]
        self.net = nn.Sequential(*layers)
        self.log_std = nn.Parameter(torch.ones(act_dim) * log_std_init)

    def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(mu, log_std)``; ``log_std`` is shared across the batch."""
        mu = self.net(obs)
        return mu, self.log_std

    @staticmethod
    def _atanh(x: torch.Tensor) -> torch.Tensor:
        """Inverse tanh with the input clamped away from ±1 to keep it finite."""
        eps = 1e-6
        x = torch.clamp(x, -1 + eps, 1 - eps)
        return 0.5 * torch.log((1 + x) / (1 - x))

    def _to_tensor(self, x) -> torch.Tensor:
        """Convert ``x`` to a float32 tensor on the same device as the parameters.

        NumPy arrays are copied (as before); tensors are only cast/moved when
        needed, so autograd history on tensor inputs is preserved.
        """
        device = self.log_std.device
        if isinstance(x, np.ndarray):
            return torch.tensor(x, dtype=torch.float32, device=device)
        return torch.as_tensor(x, dtype=torch.float32, device=device)

    def log_prob(self, obs, act_squash) -> torch.Tensor:
        """Log-density of squashed actions ``act_squash`` given ``obs``.

        Accepts NumPy arrays or tensors, single (1-D) or batched (2-D). Returns
        a tensor of shape [B] for B > 1, or a 0-d tensor for a single sample.
        """
        obs = self._to_tensor(obs)
        act_squash = self._to_tensor(act_squash)
        if obs.dim() == 1:
            obs = obs.unsqueeze(0)
        if act_squash.dim() == 1:
            act_squash = act_squash.unsqueeze(0)

        mu, log_std = self.forward(obs)
        std = torch.exp(log_std)
        u = self._atanh(act_squash)  # recover the pre-tanh action
        var = std.pow(2)
        log_base = -0.5 * (((u - mu) ** 2) / var + 2 * log_std + math.log(2 * math.pi))
        log_base = log_base.sum(dim=-1)
        # Change of variables a = tanh(u): subtract sum log|d tanh/du| = sum log(1 - a^2).
        log_det = torch.log(1 - act_squash.pow(2) + 1e-6).sum(dim=-1)
        return (log_base - log_det).squeeze(-1)

    def act(self, obs) -> np.ndarray:
        """Sample a squashed action for ``obs`` and return it as a NumPy array."""
        obs_t = self._to_tensor(obs)
        if obs_t.dim() == 1:
            obs_t = obs_t.unsqueeze(0)
        # Sampling never needs gradients; avoid building a graph per env step.
        with torch.no_grad():
            mu, log_std = self.forward(obs_t)
            std = torch.exp(log_std)
            u = mu + std * torch.randn_like(mu)
            a = torch.tanh(u)
        return a.squeeze(0).cpu().numpy()

# ------------------------------------------------------------
# 3.  Simulated human feedback
# ------------------------------------------------------------
class PrefModel(enum.Enum):
    """Link function that turns a return gap into a preference probability."""

    BT = "BradleyTerry"  # logistic link 1 / (1 + exp(-x))
    WEIBULL = "Weibull"  # link exp(-exp(-x))

class HumanPanel:
    """Simulated panel of ``M`` annotators answering pairwise preference queries.

    For returns ``R1`` and ``R0``, each annotator independently prefers
    trajectory 1 with probability ``sigma(R1 - R0)``. ``sigma`` is the logistic
    link (``PrefModel.BT``) or ``exp(-exp(-x))`` (any other model).
    """

    def __init__(self, model: PrefModel, M: int):
        self.model = model
        self.M = M

    @staticmethod
    def _sigma_bt(x):
        """Logistic link 1/(1+exp(-x)), computed without overflow for large |x|."""
        x = np.asarray(x, dtype=np.float64)
        e = np.exp(-np.abs(x))  # always in (0, 1], so it cannot overflow
        out = np.where(x >= 0, 1.0 / (1.0 + e), e / (1.0 + e))
        return out.item() if out.ndim == 0 else out

    @staticmethod
    def _sigma_weibull(x):
        """Link exp(-exp(-x)).

        The inner exponent is capped at 700: exp(-exp(700)) is already 0.0 in
        float64, and the cap avoids the OverflowError that ``math.exp`` raises
        for return gaps below about -709.
        """
        x = np.asarray(x, dtype=np.float64)
        out = np.exp(-np.exp(np.minimum(-x, 700.0)))
        return out.item() if out.ndim == 0 else out

    def prob_prefer(self, R1: float, R0: float) -> float:
        """Probability that a single annotator prefers trajectory 1."""
        x = R1 - R0
        if self.model == PrefModel.BT:
            return self._sigma_bt(x)
        else:
            return self._sigma_weibull(x)

    def query(self, R1: float, R0: float) -> Tuple[int, float]:
        """Ask all ``M`` annotators; return ``(majority, phat)``.

        ``phat`` is the fraction of annotators preferring trajectory 1.
        ``majority`` is 1 when ``phat >= 0.5``, so ties (even ``M``) go to trajectory 1.
        """
        p = self.prob_prefer(R1, R0)
        votes = np.random.rand(self.M) < p
        phat = votes.mean()
        majority = int(phat >= 0.5)
        return majority, phat

# ------------------------------------------------------------
# 4.  Trajectory generators & helpers
# ------------------------------------------------------------
def roll_out(env: WalkerEnv, pol: GaussianPolicy, seed: int | None = None):
    """Run ``pol`` in ``env`` for one full episode.

    Returns ``(traj, R)``: ``traj`` is a list of ``(state, action, reward)``
    tuples, where ``state`` is the observation the action was chosen from, and
    ``R`` is the undiscounted sum of rewards.
    """
    state = env.reset(seed=seed)
    traj = []
    total = 0.0
    done = False
    while not done:
        action = pol.act(state)
        step = env.step(action)
        traj.append((state, action, step.r))
        total += step.r
        done = step.done
        state = step.s
    return traj, total

def sample_pairs(policy_a, policy_b, N, env_factory: Callable[[], WalkerEnv]):
    """
    Lazily yield ``N`` tuples ``(tau0, R0, tau1, R1)``.

    ``tau0``/``R0`` come from ``policy_a`` and ``tau1``/``R1`` from ``policy_b``.
    One env is built per call, on first iteration. For fairness, each pair uses
    the same episode seed for both policies. NOTE(review): the shared seed only
    has an effect if the env's reset actually applies it.
    """
    env = env_factory()
    for _ in range(N):
        episode_seed = random.randint(0, 2**31 - 1)
        first = roll_out(env, policy_a, seed=episode_seed)
        second = roll_out(env, policy_b, seed=episode_seed)
        yield (*first, *second)

# ------------------------------------------------------------
# 5.  ZPG (Alg-1) for walker
# ------------------------------------------------------------
@dataclass
class ZPGConfig:
    """Hyper-parameters for ZPG (Alg-1)."""

    T: int = 200  # number of outer iterations
    N: int = 2  # preference pairs collected per iteration
    M: int = 5  # Bernoulli draws (simulated annotators) per query
    μ: float = 0.05  # perturbation radius
    α: float = 0.01  # step size
    link: PrefModel = PrefModel.BT  # link used by the panel and by its inverse
    trim: float = 1e-2  # phat is clipped to [trim, 1 - trim] before the inverse link

class ZPG:
    """Zeroth-order policy gradient from pairwise preferences (Alg-1) for walker envs.

    Each iteration:
    1. Draws a random unit direction ``v`` over the flattened policy parameters.
    2. Builds a perturbed policy at ``θ + μ v``.
    3. Asks the simulated panel to compare same-seed rollouts of the current
       and perturbed policies.
    4. Maps the empirical preference rate through the inverse link to a
       return-gap estimate.
    5. Moves ``θ`` a fixed distance ``α`` along (approximately) ``sign(gap) * v``.
       The gradient estimate is normalised first; a zero gap gives no step.
    """

    def __init__(self, cfg: ZPGConfig, env_factory: Callable[[], WalkerEnv]):
        self.cfg = cfg
        probe_env = env_factory()  # only used to read observation/action sizes
        self.env_factory = env_factory
        self.policy = GaussianPolicy(probe_env.obs_dim, probe_env.act_dim).to(DEVICE)
        self.panel = HumanPanel(cfg.link, cfg.M)
        with torch.no_grad():
            # Flattened parameter vector; ZPG.run() refreshes it after every step.
            self.theta = parameters_to_vector(self.policy.parameters()).detach()
        self.d = self.theta.numel()
        self.total_env_steps = 0

    def _σinv(self, p: float) -> float:
        """Inverse of the panel's link: map a preference rate to an estimated gap R1 - R0.

        ``p`` is clipped to ``[trim, 1 - trim]`` so the result stays finite.
        """
        p = float(np.clip(p, self.cfg.trim, 1. - self.cfg.trim))
        if self.cfg.link == PrefModel.BT:
            return math.log(p / (1 - p))
        else:
            return -math.log(-math.log(p + 1e-12) + 1e-12)

    def _perturbed_policy(self, θ_plus: torch.Tensor) -> GaussianPolicy:
        """Return a copy of the current policy with its parameters set to ``θ_plus``."""
        pol_plus = copy.deepcopy(self.policy)
        with torch.no_grad():
            vector_to_parameters(θ_plus, pol_plus.parameters())
        return pol_plus

    def _estimate_gap(self, pol_plus: GaussianPolicy) -> float:
        """Average inverse-link gap over ``cfg.N`` same-seed pairs (current vs ``pol_plus``).

        Also adds the env steps of both rollouts of every pair to ``total_env_steps``.
        """
        gap_est = 0.0
        for tau0, R0, tau1, R1 in sample_pairs(self.policy, pol_plus, self.cfg.N, self.env_factory):
            self.total_env_steps += len(tau0) + len(tau1)
            _, phat = self.panel.query(R1, R0)
            gap_est += self._σinv(phat)
        return gap_est / max(1, self.cfg.N)

    def run(self):
        """Run ``cfg.T`` iterations; return the list of per-iteration evaluation returns."""
        hist = []
        for t in range(self.cfg.T):
            with torch.no_grad():
                θ = parameters_to_vector(self.policy.parameters()).detach()
            v = uniform_unit_sphere(self.d)
            pol_plus = self._perturbed_policy(θ + self.cfg.μ * v)

            gap_est = self._estimate_gap(pol_plus)

            # Zeroth-order gradient estimate, then normalised ascent: only its
            # direction (sign(gap) * v) matters after normalisation.
            g = (self.d / self.cfg.μ) * gap_est * v
            g = g / (g.norm() + 1e-8)
            θ_new = θ + self.cfg.α * g

            with torch.no_grad():
                vector_to_parameters(θ_new, self.policy.parameters())
            # Keep the public parameter snapshot in sync. Clone because
            # vector_to_parameters makes the module parameters views of θ_new.
            self.theta = θ_new.detach().clone()

            # Simple evaluation (not counted in total_env_steps).
            _, R_eval = roll_out(self.env_factory(), self.policy)
            hist.append(R_eval)
            print(f"[ZPG] iter {t+1:4d}  return={R_eval:8.3f}")
            wandb.log({"return": R_eval, "ZPG_env_steps": self.total_env_steps}, step=self.total_env_steps)
        return hist

# ------------------------------------------------------------
# 6.  ZBCPG (Alg-2) for walker
# ------------------------------------------------------------
@dataclass
class ZBCPGConfig(ZPGConfig):
    """ZPG hyper-parameters plus the block size used by ZBCPG (Alg-2)."""

    K: int = 256  # coordinates perturbed per iteration (capped at the parameter count)

class ZBCPG(ZPG):
    """Zeroth-order block-coordinate policy gradient (Alg-2) for walker envs.

    Same loop as ZPG, except that the perturbation direction is a sparse
    random-sign vector on ``min(K, d)`` randomly chosen parameter coordinates
    (built by ``block_rademacher``).
    """

    def __init__(self, cfg: ZBCPGConfig, env_factory: Callable[[], WalkerEnv]):
        super().__init__(cfg, env_factory)
        self.cfg: ZBCPGConfig = cfg

    def run(self):
        """Run ``cfg.T`` iterations; return the list of per-iteration evaluation returns."""
        hist = []
        for t in range(self.cfg.T):
            with torch.no_grad():
                θ = parameters_to_vector(self.policy.parameters()).detach()
            coords = random.sample(range(self.d), min(self.cfg.K, self.d))
            v = block_rademacher(self.d, coords)
            θ_plus = θ + self.cfg.μ * v

            pol_plus = copy.deepcopy(self.policy)
            with torch.no_grad():
                vector_to_parameters(θ_plus, pol_plus.parameters())

            gap_est = 0.0
            for tau0, R0, tau1, R1 in sample_pairs(self.policy, pol_plus, self.cfg.N, self.env_factory):
                self.total_env_steps += len(tau0) + len(tau1)
                _, phat = self.panel.query(R1, R0)
                gap_est += self._σinv(phat)
            gap_est /= max(1, self.cfg.N)

            # Normalised ascent: the step has length α along (approximately) sign(gap) * v.
            g = (self.d / self.cfg.μ) * gap_est * v
            θ_new = θ + self.cfg.α * g / (g.norm() + 1e-8)
            with torch.no_grad():
                vector_to_parameters(θ_new, self.policy.parameters())
            # Keep the public parameter snapshot in sync. Clone because
            # vector_to_parameters makes the module parameters views of θ_new.
            self.theta = θ_new.detach().clone()

            _, R_eval = roll_out(self.env_factory(), self.policy)
            hist.append(R_eval)
            print(f"[ZBCPG] iter {t+1:4d} return={R_eval:8.3f}")
            wandb.log({"return": R_eval, "ZBCPG_env_steps": self.total_env_steps}, step=self.total_env_steps)
        return hist

# ------------------------------------------------------------
# 7.  Reward-model + PPO baseline (walker)
# ------------------------------------------------------------
class RewardMLP(nn.Module):
    """Per-step reward model: MLP over the concatenated (state, action) -> one scalar per row."""

    def __init__(self, obs_dim: int, act_dim: int, hidden=(64,64)):
        super().__init__()
        widths = [obs_dim + act_dim, *hidden]
        blocks = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            blocks.extend((nn.Linear(fan_in, fan_out), nn.Tanh()))
        blocks.append(nn.Linear(widths[-1], 1))
        self.net = nn.Sequential(*blocks)

    def forward(self, s: torch.Tensor, a: torch.Tensor) -> torch.Tensor:
        """Return rewards of shape [B]; 1-D ``s``/``a`` are treated as a batch of one."""
        s = s.unsqueeze(0) if s.dim() == 1 else s
        a = a.unsqueeze(0) if a.dim() == 1 else a
        return self.net(torch.cat((s, a), dim=-1)).squeeze(-1)





class RMPPO:
    """Two-stage preference baseline: fit a reward model from panel votes, then PPO on it.

    Stage 1 (``train_rm``): roll out ``traj_pairs`` same-seed trajectory pairs
    from ``self.behaviour`` (a policy this class never optimises), query the
    panel's majority vote, and fit ``reward_net`` with BCE on the difference of
    summed per-step predicted rewards.

    Stage 2 (``run``): clipped PPO with GAE on ``policy``/``value``, using
    ``reward_net(s, a)`` as the reward instead of the environment reward.
    """
    def __init__(
        self,
        panel: HumanPanel,
        env_factory: Callable[[], WalkerEnv],
        traj_pairs: int = 2,
        rm_epochs: int = 5,
        rm_lr: float = 3e-3,
        rm_batch_size: int = 32,
        # PPO hyperparams
        ppo_iters: int = 500,
        steps_per_update: int = 4096,
        ppo_epochs: int = 10,
        minibatch_size: int = 256,
        lr: float = 3e-4,
        clip_eps: float = 0.2,
        ent_coef: float = 0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 1.0,
        target_kl: float = 0.01,
        # GAE
        γ: float = 0.99,
        lam: float = 0.95,
    ):
        """Build the behaviour policy, reward model, PPO policy/value nets and their optimizer.

        ``traj_pairs`` is the number of preference pairs used to train the
        reward model; ``steps_per_update`` is the minimum number of env steps
        collected per PPO iteration (whole episodes are always kept).
        """
        assert panel is not None, "RMPPO requires a HumanPanel for reward-model training."
        self.panel = panel
        self.env_factory = env_factory
        self.traj_pairs = traj_pairs
        self.rm_epochs = rm_epochs
        self.rm_lr = rm_lr
        self.rm_batch_size = rm_batch_size

        probe_env = env_factory()  # only used to read observation/action sizes
        self.behaviour = GaussianPolicy(probe_env.obs_dim, probe_env.act_dim).to(DEVICE)  # for RM data collection
        self.reward_net = RewardMLP(probe_env.obs_dim, probe_env.act_dim).to(DEVICE)

        # PPO policy + value
        # ValueMLP is defined later in the module; the name is resolved when this runs.
        self.policy = GaussianPolicy(probe_env.obs_dim, probe_env.act_dim).to(DEVICE)
        self.value = ValueMLP(probe_env.obs_dim).to(DEVICE)
        # One optimizer over policy and value parameters (shared learning rate).
        self.optimizer = optim.Adam(list(self.policy.parameters()) + list(self.value.parameters()), lr=lr)

        # PPO/GAE settings
        self.γ = γ
        self.lam = lam
        self.clip_eps = clip_eps
        self.ent_coef = ent_coef
        self.vf_coef = vf_coef
        self.max_grad_norm = max_grad_norm
        self.target_kl = target_kl
        self.ppo_iters = ppo_iters
        self.steps_per_update = steps_per_update
        self.ppo_epochs = ppo_epochs
        self.minibatch_size = minibatch_size

        # Counts RM-collection and PPO-collection steps; evaluation rollouts are excluded.
        self.total_env_steps = 0

    def _R_pred(self, traj) -> torch.Tensor:
        """Predicted return of ``traj``: sum of ``reward_net`` over its (state, action) steps."""
        s = torch.tensor(np.stack([x[0] for x in traj], axis=0), dtype=torch.float32, device=DEVICE)
        a = torch.tensor(np.stack([x[1] for x in traj], axis=0), dtype=torch.float32, device=DEVICE)
        rs = self.reward_net.forward(s, a)
        return rs.sum()

    def train_rm(self):
        """Collect ``traj_pairs`` labelled pairs and fit ``reward_net`` for ``rm_epochs`` epochs.

        Labels are the panel's majority vote (1 if trajectory 1 is preferred);
        the logit is ``R_pred(tau1) - R_pred(tau0)``.
        """
        self.reward_net.train()
        opt = optim.Adam(self.reward_net.parameters(), lr=self.rm_lr)
        env_f = self.env_factory

        # Collect pairwise preferences (fair seeds)
        # Both trajectories of a pair come from self.behaviour with the same episode seed.
        batch = []
        print("[RM] collecting data ...")
        for tau0, R0, tau1, R1 in sample_pairs(self.behaviour, self.behaviour, self.traj_pairs, env_f):
            self.total_env_steps += len(tau0) + len(tau1)
            majority, _ = self.panel.query(R1, R0)  # y in {0,1}
            batch.append((tau0, tau1, int(majority)))
        print(f"[RM] data size = {len(batch)}")

        # Train with mini-batches over pairs
        for epoch in range(self.rm_epochs):
            random.shuffle(batch)
            for start in range(0, len(batch), self.rm_batch_size):
                chunk = batch[start:start + self.rm_batch_size]
                logits = []
                targets = []
                for tau0, tau1, y in chunk:
                    R1 = self._R_pred(tau1)
                    R0 = self._R_pred(tau0)
                    logits.append((R1 - R0).view(1))
                    targets.append([float(y)])
                if not logits:
                    continue
                logits = torch.cat(logits, dim=0)  # [B]
                targets = torch.tensor(targets, dtype=torch.float32, device=DEVICE).view(-1)

                opt.zero_grad()
                loss = F.binary_cross_entropy_with_logits(logits, targets)
                loss.backward()
                opt.step()
            print(f"[RM] epoch {epoch+1} done.")

    @torch.no_grad()
    def _compute_gae(self, r: torch.Tensor, v: torch.Tensor, done: torch.Tensor):
        """
        r:    [N] float32, pseudo-reward from reward_net
        v:    [N] float32, V(s_t)
        done: [N] bool, True if terminal at s_t (the step's transition ends episode)
        Returns:
          adv: [N], ret: [N] where ret = adv + v

        Every episode end is treated as terminal (no bootstrap), including ends
        caused by time limits / ``max_steps``, since ``done`` does not
        distinguish truncation. The last index bootstraps from 0; batches built
        by ``_collect_batch`` always end on an episode boundary.
        """
        N = r.shape[0]
        adv = torch.zeros_like(r)
        lastgaelam = 0.0
        for t in reversed(range(N)):
            nonterminal = 0.0 if done[t].item() else 1.0
            if t == N - 1:
                next_v = 0.0
            else:
                # zero bootstrap across episode boundaries (use done[t], not done[t+1])
                next_v = v[t + 1].item() * (1.0 - float(done[t].item()))
            delta = r[t].item() + self.γ * next_v - v[t].item()
            # nonterminal also stops the advantage recursion at episode boundaries.
            lastgaelam = delta + self.γ * self.lam * nonterminal * lastgaelam
            adv[t] = lastgaelam
        ret = adv + v
        return adv, ret

    @torch.no_grad()
    def _collect_batch(self, env: WalkerEnv, policy: GaussianPolicy, steps_target: int):
        """
        Collects at least steps_target steps using full episodes with random seeds.
        Returns tensors on DEVICE: s [N,obs], a [N,act], r_hat [N], done [N]

        The environment reward is discarded; ``r_hat`` comes from ``reward_net``.
        """
        states, actions, dones = [], [], []
        steps = 0
        while steps < steps_target:
            seed = random.randint(0, 2**31 - 1)
            traj, _ = roll_out(env, policy, seed=seed)
            T = len(traj)
            for i, (s, a, _r_true) in enumerate(traj):
                states.append(s)
                actions.append(a)
                dones.append(i == T - 1)  # last step of each episode
            steps += T

        s = torch.tensor(np.stack(states, axis=0), dtype=torch.float32, device=DEVICE)
        a = torch.tensor(np.stack(actions, axis=0), dtype=torch.float32, device=DEVICE)
        d = torch.tensor(dones, dtype=torch.bool, device=DEVICE)

        # Vectorized pseudo-rewards from reward model
        r_hat = self.reward_net.forward(s, a)  # [N]
        return s, a, r_hat, d

    def run(self):
        """Train the reward model once, then run ``ppo_iters`` PPO iterations on its rewards.

        Returns the list of per-iteration evaluation returns, measured with the
        true environment reward.
        """
        # 1) Train reward model
        self.train_rm()

        # 2) PPO on pseudo-rewards
        env = self.env_factory()
        hist = []

        for it in range(self.ppo_iters):
            # Behavior snapshot
            # Frozen copy so logp_old stays fixed while self.policy is updated below.
            old_policy = copy.deepcopy(self.policy)

            # Collect batch with old_policy
            s, a, r_hat, d = self._collect_batch(env, old_policy, self.steps_per_update)
            N = s.shape[0]
            self.total_env_steps += N

            with torch.no_grad():
                v = self.value(s)
                adv, ret = self._compute_gae(r_hat, v, d)
                # Normalize advantages
                adv = (adv - adv.mean()) / (adv.std() + 1e-8)
                logp_old = old_policy.log_prob(s, a)

            # PPO epochs over minibatches
            idx_all = torch.arange(N, device=DEVICE)
            mb_size = min(self.minibatch_size, N)

            for epoch in range(self.ppo_epochs):
                perm = idx_all[torch.randperm(N, device=DEVICE)]
                for start in range(0, N, mb_size):
                    mb = perm[start:start + mb_size]

                    # Clipped surrogate objective.
                    logp = self.policy.log_prob(s[mb], a[mb])
                    ratio = torch.exp(logp - logp_old[mb])
                    surr1 = ratio * adv[mb]
                    surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * adv[mb]
                    policy_loss = -torch.min(surr1, surr2).mean()

                    # Unclipped value regression onto the GAE returns.
                    v_pred = self.value(s[mb])
                    value_loss = 0.5 * (ret[mb] - v_pred).pow(2).mean()

                    # Approximate base Normal entropy (ignores tanh transform)
                    approx_entropy = (self.policy.log_std + 0.5 * math.log(2 * math.pi * math.e)).sum()

                    loss = policy_loss + self.vf_coef * value_loss - self.ent_coef * approx_entropy

                    self.optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        list(self.policy.parameters()) + list(self.value.parameters()),
                        self.max_grad_norm
                    )
                    self.optimizer.step()

                # Early stop by target KL
                # Checked once per epoch, on the whole batch, after all minibatch updates.
                with torch.no_grad():
                    new_logp = self.policy.log_prob(s, a)
                    approx_kl = (logp_old - new_logp).mean().item()
                if approx_kl > self.target_kl:
                    break

            # Evaluate current policy (does not add to training steps)
            # A fresh env is built for every evaluation rollout.
            eval_seed = random.randint(0, 2**31 - 1)
            _, R_eval = roll_out(self.env_factory(), self.policy, seed=eval_seed)
            hist.append(R_eval)
            print(f"[RM+PPO] iter {it + 1:4d} return={R_eval:8.3f}")
            wandb.log({"return": R_eval, "RMPPO_env_steps": self.total_env_steps}, step=self.total_env_steps)

        return hist



# ------------------------------------------------------------
# 8.  DFA (pairwise DPO-like) for walker
# ------------------------------------------------------------
class DFA:
    """Pairwise, DPO-like optimisation directly on trajectory log-likelihoods.

    Each iteration samples ``N_pairs`` same-seed trajectory pairs from the
    current policy, gets the panel's empirical preference rate ``phat`` for each
    pair, and minimises ``BCE(beta * (log pi(tau1) - log pi(tau0)), phat)``,
    averaged over the pairs. ``online`` only changes the log tag.
    """

    def __init__(self, online=False, beta=1e-5, panel: HumanPanel = None,
                 N_pairs=2, iters=500, env_factory: Callable[[], WalkerEnv] = None):
        self.online = online
        self.beta = beta
        self.N = N_pairs
        self.iters = iters
        self.panel = panel
        self.env_factory = env_factory
        probe_env = env_factory()  # only used to read observation/action sizes
        self.pol = GaussianPolicy(probe_env.obs_dim, probe_env.act_dim).to(DEVICE)
        # π₀. NOTE(review): not used by run(); a DPO-style loss would subtract
        # its log-likelihoods. Kept for compatibility; confirm the intent.
        self.ref = copy.deepcopy(self.pol)
        self.total_env_steps = 0

    def _traj_log_probs(self, tau) -> torch.Tensor:
        """Per-step log pi(a_t | s_t) for ``tau``, shape [T], from one batched forward pass."""
        states = np.stack([s for s, _, _ in tau], axis=0)
        actions = np.stack([a for _, a, _ in tau], axis=0)
        # reshape(-1) keeps a length-1 trajectory as shape [1] (log_prob squeezes it to 0-d).
        return self.pol.log_prob(states, actions).reshape(-1)

    def _log_prob_traj(self, tau):
        """Per-step log-probabilities of ``tau`` as a list of 0-d tensors."""
        return list(self._traj_log_probs(tau).unbind(0))

    def _traj_log_prob(self, tau) -> torch.Tensor:
        """Total log-likelihood of ``tau`` under the current policy (0-d tensor)."""
        return self._traj_log_probs(tau).sum()

    def run(self):
        """Run ``iters`` updates; return the evaluation returns logged every 10 iterations."""
        opt = optim.Adam(self.pol.parameters(), lr=4e-4)
        hist = []
        for t in range(self.iters):
            batch = []
            for tau0, R0, tau1, R1 in sample_pairs(self.pol, self.pol, self.N, self.env_factory):
                self.total_env_steps += len(tau0) + len(tau1)
                _, phat = self.panel.query(R1, R0)
                batch.append((tau0, tau1, phat))

            # With no pairs (N_pairs == 0) there is nothing to learn from; skip the update.
            if batch:
                opt.zero_grad()
                logits = torch.stack([
                    self.beta * (self._traj_log_prob(tau1) - self._traj_log_prob(tau0))
                    for tau0, tau1, _ in batch
                ])
                # Soft targets: the empirical preference rate, not the majority vote.
                targets = torch.tensor([p for _, _, p in batch], dtype=torch.float32, device=logits.device)
                # Mean reduction over pairs, i.e. the average of the per-pair BCE terms.
                loss = F.binary_cross_entropy_with_logits(logits, targets)
                loss.backward()
                opt.step()

            if (t + 1) % 10 == 0:
                _, R_eval = roll_out(self.env_factory(), self.pol)
                hist.append(R_eval)
                tag = "oDPO" if self.online else "DFA"
                print(f"[{tag}:walker] iter {t+1:4d} return={R_eval:8.3f}")
                wandb.log({"return": R_eval, "DFA_env_steps": self.total_env_steps }, step=self.total_env_steps)
        return hist

class ValueMLP(nn.Module):
    """State-value network V(s): MLP with tanh hidden layers and one scalar output per row."""

    def __init__(self, obs_dim: int, hidden=(64, 64)):
        super().__init__()
        widths = [obs_dim, *hidden]
        blocks = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            blocks.extend((nn.Linear(fan_in, fan_out), nn.Tanh()))
        blocks.append(nn.Linear(widths[-1], 1))
        self.net = nn.Sequential(*blocks)

    def forward(self, s: torch.Tensor) -> torch.Tensor:
        """Return values of shape [B]; a 1-D ``s`` is treated as a batch of one."""
        batched = s.unsqueeze(0) if s.dim() == 1 else s
        return self.net(batched).squeeze(-1)


class OraclePPO:
    """PPO baseline trained directly on the true environment reward (upper-bound reference).

    Clipped-surrogate PPO with GAE, a separate value network, a shared Adam
    optimizer, gradient clipping, and per-epoch early stopping on an
    approximate KL.
    """
    def __init__(
        self,
        ppo_iters: int = 10000,
        steps_per_update: int = 4096,
        ppo_epochs: int = 10,
        minibatch_size: int = 256,
        clip_eps: float = 0.2,
        ent_coef: float = 0,
        vf_coef: float = 0.5,
        max_grad_norm: float = 1.0,
        target_kl: float = 0.01,
        lr: float = 3e-4,
        γ: float = 0.99,
        lam: float = 0.95,
        env_factory: Callable[[], WalkerEnv] = None,
    ):
        """Build policy/value networks and their optimizer.

        ``env_factory`` is required despite its ``None`` default: it is called
        immediately to read the observation/action sizes.
        """
        self.env_factory = env_factory
        probe_env = env_factory()
        self.policy = GaussianPolicy(probe_env.obs_dim, probe_env.act_dim).to(DEVICE)
        self.value = ValueMLP(probe_env.obs_dim).to(DEVICE)

        # One optimizer over policy and value parameters (shared learning rate).
        self.optimizer = optim.Adam(
            list(self.policy.parameters()) + list(self.value.parameters()),
            lr=lr
        )

        self.γ = γ
        self.lam = lam
        self.clip_eps = clip_eps
        self.ent_coef = ent_coef
        self.vf_coef = vf_coef
        self.max_grad_norm = max_grad_norm
        self.target_kl = target_kl
        self.ppo_iters = ppo_iters
        self.steps_per_update = steps_per_update
        self.ppo_epochs = ppo_epochs
        self.minibatch_size = minibatch_size
        # Training env steps only; evaluation rollouts are excluded.
        self.total_env_steps = 0

    @torch.no_grad()
    def _compute_gae(self, r: torch.Tensor, v: torch.Tensor, done: torch.Tensor):
        """
        r:    [N] float32
        v:    [N] float32, V(s_t)
        done: [N] bool, True if the transition taken at step t ends the episode
              (i.e. t is the last step of a rollout)
        Returns:
          adv: [N], ret: [N] where ret = adv + v

        Every episode end is treated as terminal (no bootstrap), including
        time-limit / ``max_steps`` truncations. The last index bootstraps
        from 0; batches from ``_collect_batch`` always end on an episode boundary.
        """
        N = r.shape[0]
        adv = torch.zeros_like(r)
        lastgaelam = 0.0

        for t in reversed(range(N)):
            nonterminal = 0.0 if done[t].item() else 1.0
            if t == N - 1:
                next_v = 0.0
            else:
                # Zero bootstrap across episode boundaries
                next_v = v[t + 1].item() * (1.0 - float(done[t].item()))
            delta = r[t].item() + self.γ * next_v - v[t].item()
            # nonterminal also stops the advantage recursion at episode boundaries.
            lastgaelam = delta + self.γ * self.lam * nonterminal * lastgaelam
            adv[t] = lastgaelam

        ret = adv + v
        return adv, ret

    @torch.no_grad()
    def _collect_batch(self, env: WalkerEnv, policy: GaussianPolicy, steps_target: int):
        """
        Collects at least steps_target steps by rolling out full episodes with random seeds.
        Returns tensors on DEVICE: s [N,obs], a [N,act], r [N], done [N]
        """
        states, actions, rewards, dones = [], [], [], []
        steps = 0
        while steps < steps_target:
            seed = random.randint(0, 2**31 - 1)
            traj, _ = roll_out(env, policy, seed=seed)
            T = len(traj)
            for i, (s, a, r) in enumerate(traj):
                states.append(s)
                actions.append(a)
                rewards.append(r)
                dones.append(i == T - 1)  # terminal at last step of episode
            steps += T

        s = torch.tensor(np.stack(states, axis=0), dtype=torch.float32, device=DEVICE)
        a = torch.tensor(np.stack(actions, axis=0), dtype=torch.float32, device=DEVICE)
        r = torch.tensor(rewards, dtype=torch.float32, device=DEVICE)
        d = torch.tensor(dones, dtype=torch.bool, device=DEVICE)
        return s, a, r, d

    def run(self):
        """Run ``ppo_iters`` PPO iterations; return the list of per-iteration evaluation returns."""
        env = self.env_factory()
        hist = []

        for it in range(self.ppo_iters):
            # Snapshot behavior policy
            # Frozen copy so logp_old stays fixed while self.policy is updated below.
            old_policy = copy.deepcopy(self.policy)

            # Collect a batch of transitions using old_policy
            s, a, r, d = self._collect_batch(env, old_policy, self.steps_per_update)
            N = s.shape[0]
            self.total_env_steps += N

            with torch.no_grad():
                v = self.value(s)
                adv, ret = self._compute_gae(r, v, d)
                # Normalize advantages
                adv = (adv - adv.mean()) / (adv.std() + 1e-8)
                logp_old = old_policy.log_prob(s, a)

            # PPO epochs over minibatches
            idx_all = torch.arange(N, device=DEVICE)
            mb_size = min(self.minibatch_size, N)

            for epoch in range(self.ppo_epochs):
                perm = idx_all[torch.randperm(N, device=DEVICE)]
                for start in range(0, N, mb_size):
                    mb = perm[start:start + mb_size]

                    # Clipped surrogate objective.
                    logp = self.policy.log_prob(s[mb], a[mb])
                    ratio = torch.exp(logp - logp_old[mb])
                    surr1 = ratio * adv[mb]
                    surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * adv[mb]
                    policy_loss = -torch.min(surr1, surr2).mean()

                    # Unclipped value regression onto the GAE returns.
                    v_pred = self.value(s[mb])
                    value_loss = 0.5 * (ret[mb] - v_pred).pow(2).mean()

                    # Approximate entropy of base Normal (ignoring tanh correction)
                    approx_entropy = (self.policy.log_std + 0.5 * math.log(2 * math.pi * math.e)).sum()

                    loss = policy_loss + self.vf_coef * value_loss - self.ent_coef * approx_entropy

                    self.optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        list(self.policy.parameters()) + list(self.value.parameters()),
                        self.max_grad_norm
                    )
                    self.optimizer.step()

                # Early stop by target KL
                # Checked once per epoch, on the whole batch, after all minibatch updates.
                with torch.no_grad():
                    new_logp = self.policy.log_prob(s, a)
                    approx_kl = (logp_old - new_logp).mean().item()
                if approx_kl > self.target_kl:
                    # print(f"[Oracle-PPO] early stop at epoch {epoch+1} (KL={approx_kl:.4f} > {self.target_kl})")
                    break

            # Evaluation rollout (does not count toward training steps)
            # A fresh env is built for every evaluation rollout.
            eval_seed = random.randint(0, 2**31 - 1)
            _, R_eval = roll_out(self.env_factory(), self.policy, seed=eval_seed)
            hist.append(R_eval)
            print(f"[Oracle-PPO] iter {it + 1:4d} return={R_eval:8.3f}")
            wandb.log({"return": R_eval, "OraclePPO_env_steps": self.total_env_steps},
                      step=self.total_env_steps)

        return hist
# ------------------------------------------------------------
# 10.  Main
# ------------------------------------------------------------
if __name__ == "__main__":
    set_global_seeds(3)
    # Which learner to run: one of "ppo", "ZPG", "ZBC", "rm", "dfa" (case-sensitive).
    method="dfa"
    # NOTE(review): Humanoid-v5 needs a Gymnasium release that ships the MuJoCo v5
    # envs (plus mujoco installed); the header says gymnasium 0.29 — confirm.
    # Its action bounds may be narrower than tanh's [-1, 1]; WalkerEnv.step clips.
    ENV_ID = "Humanoid-v5"
    MAX_STEPS = 1000  # limit episode length to speed things up

    VALID_METHODS = ("ppo", "ZPG", "ZBC", "rm", "dfa")
    if method not in VALID_METHODS:
        # Fail before wandb.init so a typo does not create an empty run that trains nothing.
        raise ValueError(f"unknown method {method!r}; expected one of {VALID_METHODS}")

    env_factory = lambda: WalkerEnv(env_id=ENV_ID, max_steps=MAX_STEPS, seed=3)
    M=5
    panel_bt = HumanPanel(PrefModel.BT, M=M)
    beta=1e-3
    rm_pairs=5
    N_pairs_dfa=5
    # Initialize wandb
    wandb.init(project="rlhf-complex", name=ENV_ID+"-"+method+"-beta:"+str(beta)+"-rm_pairs:"+str(rm_pairs)+"-N_pairs_dfa:"+str(N_pairs_dfa)+"lrdfa=4-noseed"+"M:"+str(M))
    try:
        if method=="ppo":
            # ---------- Oracle PPO (true reward) ----------------------------------
            oracle_ppo = OraclePPO(ppo_iters=1000000, env_factory=env_factory)
            oracle_ppo.run()

        elif method=="ZPG":
            # ---------- ZPG --------------------------------------------------------
            cfg_zpg = ZPGConfig(T=100, N=2, M=3, μ=0.05, α=0.01)
            zpg = ZPG(cfg_zpg, env_factory=env_factory)
            zpg.run()

        elif method=="ZBC":
            # ---------- ZBCPG ------------------------------------------------------
            cfg_zb = ZBCPGConfig(T=100, N=2, M=3, μ=0.05, α=0.01, K=512)
            zb = ZBCPG(cfg_zb, env_factory=env_factory)
            zb.run()

        elif method=="rm":
            # ---------- RM + PPO ---------------------------------------------------
            rmppo = RMPPO(panel=panel_bt, env_factory=env_factory, traj_pairs=rm_pairs, ppo_iters=100000)
            rmppo.run()

        elif method=="dfa":
            # ---------- DFA --------------------------------------------------------
            dfa = DFA(online=False, panel=panel_bt, iters=1000000, N_pairs=N_pairs_dfa, beta=beta, env_factory=env_factory)
            dfa.run()
    finally:
        # Close the wandb run even if training raises or is interrupted.
        wandb.finish()